#include "PCH.h"
#include "DxbcContainer.h"

#include "DxbcChunk.h"
#include "InputSignatureChunk.h"
#include "InterfacesChunk.h"
#include "OutputSignatureChunk.h"
#include "PatchConstantSignatureChunk.h"
#include "ResourceDefinitionChunk.h"
#include "ShaderProgramChunk.h"
#include "StatisticsChunk.h"

using namespace boolinq;
using namespace std;
using namespace SlimShader;

// Parses a DXBC container from a raw byte buffer.
//
// bytes: the complete container bytes. They are wrapped in a BytecodeReader
//        and handed to the BytecodeReader& overload of Parse.
// Returns the container produced by that overload.
DxbcContainer DxbcContainer::Parse(const vector<char> bytes)
{
	// data() is well-defined even when the vector is empty, whereas &bytes[0]
	// is undefined behaviour in that case.
	const uint8_t* bytesPointer = reinterpret_cast<const uint8_t*>(bytes.data());

	// The other overload takes a non-const lvalue reference. Standard C++ does
	// not allow a temporary to bind to one (only a compiler extension accepts
	// it), so the reader has to be a named local.
	BytecodeReader reader(bytesPointer, bytes.size());
	return Parse(reader);
}

// Builds a DxbcContainer from `reader`.
//
// The header is parsed through a copy of `reader`. That same copy is then used
// to read one 32-bit offset per chunk (as many as the header's chunk count).
// Each offset is passed to reader.CopyAtOffset() to obtain the reader from
// which DxbcChunk::Parse builds the chunk.
DxbcContainer DxbcContainer::Parse(BytecodeReader& reader)
{
	DxbcContainer result;

	// All header and offset reads go through this copy rather than `reader`.
	BytecodeReader offsetTableReader(reader);
	result._header = DxbcContainerHeader::Parse(offsetTableReader);

	uint32_t chunkIndex = 0;
	while (chunkIndex < result._header.GetChunkCount())
	{
		auto offset = offsetTableReader.ReadUInt32();
		auto readerForChunk = reader.CopyAtOffset(offset);

		// The partially-built container is handed to the chunk parser.
		auto parsedChunk = DxbcChunk::Parse(readerForChunk, result);
		result._chunks.push_back(parsedChunk);

		++chunkIndex;
	}

	return result;
}

// Returns the first chunk in `chunks` whose type is `type1` (or `type2`, when
// one is supplied), cast to T via dynamic_pointer_cast.
//
// chunks: chunk list to search; taken by const reference so that a lookup does
//         not copy the vector and every shared_ptr in it.
// type1:  chunk type to look for.
// type2:  optional alternative chunk type. The default, ChunkType::Unknown,
//         means "no alternative"; it is not treated as a type to match, so a
//         chunk of unknown type cannot shadow a real match later in the list.
//
// Returns nullptr when no chunk matches, or when the first matching chunk is
// not actually a T (the dynamic cast fails).
template <class T>
const shared_ptr<T> FindChunk(const vector<shared_ptr<DxbcChunk>>& chunks, ChunkType type1, ChunkType type2 = ChunkType::Unknown)
{
	const bool hasSecondType = type2 != ChunkType::Unknown;

	// Stop at the first hit instead of collecting every match and then
	// discarding all but the first.
	for (const auto& chunk : chunks)
	{
		const auto chunkType = chunk->GetChunkType();
		if (chunkType == type1 || (hasSecondType && chunkType == type2))
			return dynamic_pointer_cast<T>(chunk);
	}
	return nullptr;
}

// Looks up the chunk of type ChunkType::Rdef via FindChunk and returns it as a
// ResourceDefinitionChunk; the result is null when FindChunk yields nothing.
const shared_ptr<ResourceDefinitionChunk> DxbcContainer::GetResourceDefinition() const
{
	auto resourceDefinition = FindChunk<ResourceDefinitionChunk>(_chunks, ChunkType::Rdef);
	return resourceDefinition;
}

const std::shared_ptr<PatchConstantSignatureChunk> DxbcContainer::GetPatchConstantSignature() const
{
	return FindChunk<PatchConstantSignatureChunk>(_chunks, ChunkType::Pcsg);
}

const std::shared_ptr<InputSignatureChunk> DxbcContainer::GetInputSignature() const
{
	return FindChunk<InputSignatureChunk>(_chunks, ChunkType::Isgn);
}

const std::shared_ptr<OutputSignatureChunk> DxbcContainer::GetOutputSignature() const
{
	return FindChunk<OutputSignatureChunk>(_chunks, ChunkType::Osgn, ChunkType::Osg5);
}

const std::shared_ptr<ShaderProgramChunk> DxbcContainer::GetShader() const
{
	return FindChunk<ShaderProgramChunk>(_chunks, ChunkType::Shdr, ChunkType::Shex);
}

const std::shared_ptr<StatisticsChunk> DxbcContainer::GetStatistics() const
{
	return FindChunk<StatisticsChunk>(_chunks, ChunkType::Stat);
}

const std::shared_ptr<InterfacesChunk> DxbcContainer::GetInterfaces() const
{
	return FindChunk<InterfacesChunk>(_chunks, ChunkType::Ifce);
}

// Writes a textual listing of `container` to `out` and returns `out`.
//
// The listing is a fixed comment banner followed by, in order: resource
// definition, patch constant signature, input signature, output signature,
// statistics, interfaces, shader program, and a closing line with the
// instruction count taken from the statistics chunk.
//
// Any chunk the container does not have is skipped. In particular the closing
// instruction-count line is only written when a statistics chunk exists, and
// missing input/output signatures no longer lead to a null dereference.
ostream& SlimShader::operator<<(ostream &out, const DxbcContainer &container)
{
	// Fetch each chunk once: every Get*() call searches the chunk list, and
	// holding the result lets the null check and the use share one pointer.
	const auto resourceDefinition = container.GetResourceDefinition();
	const auto patchConstantSignature = container.GetPatchConstantSignature();
	const auto inputSignature = container.GetInputSignature();
	const auto outputSignature = container.GetOutputSignature();
	const auto statistics = container.GetStatistics();
	const auto interfaces = container.GetInterfaces();
	const auto shader = container.GetShader();

	out << "// " << endl;
	out << "// Generated by SlimShader" << endl;
	out << "// " << endl;
	out << "// " << endl;
	out << "// " << endl;

	out << "//" << endl;
	out << "//" << endl;

	if (resourceDefinition != nullptr)
		out << *resourceDefinition;

	out << "//" << endl;

	if (patchConstantSignature != nullptr)
	{
		out << *patchConstantSignature;
		out << "//" << endl;
	}

	// The separator after the input signature is written unconditionally, as
	// before; only the dereference is guarded.
	if (inputSignature != nullptr)
		out << *inputSignature;
	out << "//" << endl;

	if (outputSignature != nullptr)
		out << *outputSignature;

	if (statistics != nullptr)
		out << *statistics;

	if (interfaces != nullptr)
		out << *interfaces;

	if (shader != nullptr)
		out << *shader;

	// Without a statistics chunk there is no instruction count to report.
	if (statistics != nullptr)
		out << "// Approximately " << statistics->GetInstructionCount() << " instruction slots used";

	return out;
}